package org.funny.nn.som.demo;

import org.funny.nn.som.algorithm.Som;

import javax.swing.*;
import java.awt.*;
import java.util.*;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;

/**
 * Demo that uses a self-organizing map (SOM) to group cities by position.
 *
 * <p>A trainer thread fits a {@link Som} to the normalized city coordinates and hands the
 * network to an animator thread through a one-slot {@link BlockingQueue}; the animator
 * throttles those frames and repaints the panel, which draws the SOM grid and the cities
 * grouped under their closest grid node.
 *
 * @author LinLW
 */
public class ClassificationDemo extends JFrame {

    /**
     * Value of the data watcher's {@code finish} argument at which a frame is published.
     * NOTE(review): 34 presumably equals the number of cities returned by
     * {@code TspDemo.readCities()} (one full pass over the inputs) - confirm against Som.
     */
    private static final int PUBLISH_AT_FINISH = 34;

    /** Normalized city coordinates, one {x, y} row per city; read by the panel on the EDT. */
    private final float[][] input;
    /** Cities in the same order as the rows of {@code input}. */
    private final List<City> cities;
    /** The SOM being trained; written by the trainer thread and read on the EDT, hence volatile. */
    private volatile Som som;
    /** Latest frame handed over by the animator thread; {@code null} until the first one arrives. */
    private volatile NetworkFrame current;

    /**
     * Builds the window, loads and normalizes the cities, and starts the trainer and
     * animator threads.
     */
    public ClassificationDemo() {
        this.setTitle("SOM 解决分类演示"); // window title
        DrawingPanel panel = new DrawingPanel();
        this.getContentPane().add(panel);
        this.setSize(1200, 600);
        this.setResizable(false);
        this.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);

        this.cities = TspDemo.readCities();
        float[][] inputMatrix = new float[cities.size()][2];
        for (int i = 0; i < cities.size(); i++) {
            City c = cities.get(i);
            inputMatrix[i] = new float[]{(float) c.getX(), (float) c.getY()};
        }
        TspDemo.normalize(inputMatrix); // normalization
        // Assigned before any thread starts, so the panel never sees a null input.
        this.input = inputMatrix;

        // One-slot queue: put() blocks while a frame is still waiting, so training is paced
        // by the animation instead of racing ahead of it.
        BlockingQueue<NetworkFrame> queue = new LinkedBlockingQueue<>(1);

        Thread trainer = new Thread(() -> train(inputMatrix, queue), "som-trainer");
        Thread animator = new Thread(() -> animate(queue, panel), "som-animator");
        trainer.start();
        animator.start();
    }

    /**
     * Configures a SOM for {@code inputMatrix}, trains it, and publishes a frame to
     * {@code queue} every time the watcher reports {@code finish == PUBLISH_AT_FINISH}.
     * Runs on the trainer thread.
     */
    private void train(float[][] inputMatrix, BlockingQueue<NetworkFrame> queue) {
        Som trainer = new Som();
        trainer.setIterations(100); // 100 training iterations
        // Decay factor 1 - e/100 (about 0.973). The original note says this shrinks the
        // neighbour radius below 1 within the run - NOTE(review): depends on how Som applies it.
        trainer.setDecayedRate(1.0F - (float) Math.E / 100);
        // Initial learning rate; too high a value makes competitive learning overshoot.
        trainer.setLearningRate(0.618F);
        // Grid side length = ceil(sqrt(5 * sqrt(number of inputs))).
        trainer.setNetworkSize((int) Math.ceil(Math.sqrt(5 * Math.sqrt(inputMatrix.length))));
        trainer.setNeighborRadix(trainer.getNetworkSize()); // initial neighbour radius
        trainer.setInput(inputMatrix);

        // Plain Euclidean distance on the (x, y) coordinates.
        trainer.setDistanceCalculator((a, b) -> (float) Math.sqrt((a[0] - b[0]) * (a[0] - b[0]) + (a[1] - b[1]) * (a[1] - b[1])));

        // Hook that observes training and hands frames to the animator.
        trainer.setDataWatcher((iter, finish, network) -> {
            if (finish == PUBLISH_AT_FINISH) {
                try {
                    // Network and iteration travel together, so the label always matches the grid.
                    queue.put(new NetworkFrame((float[][][]) network, iter));
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt(); // preserve the interrupt for Som/the thread
                }
            }
        });

        // Published only after configuration; the panel paints nothing until som and a frame exist.
        this.som = trainer;
        trainer.train4classifiy();
    }

    /**
     * Takes frames from {@code queue}, makes each the current frame and repaints {@code panel}.
     * Throttles to at most about 10 frames per second and holds the first frame for an extra
     * 1.5 s so the eye can take in the starting state. Stops when no frame arrives within one
     * second (presumably training has finished) or when interrupted. The last frame stays on
     * screen. Runs on the animator thread.
     */
    private void animate(BlockingQueue<NetworkFrame> queue, JPanel panel) {
        boolean firstFrame = true;
        try {
            while (true) {
                NetworkFrame frame = queue.poll(1000, TimeUnit.MILLISECONDS);
                if (frame == null) {
                    break; // keep showing the last frame rather than clearing it
                }
                current = frame;
                panel.repaint();
                if (firstFrame) {
                    Thread.sleep(1500); // 1.5 s for the eye to take in the initial state
                    firstFrame = false;
                }
                Thread.sleep(100); // the machine is too fast: cap the animation at ~10 fps
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Immutable pairing of the network reported by the SOM data watcher with the training
     * iteration it was reported for.
     * NOTE(review): whether Som passes a copy or its live weight array is not visible here;
     * if it is live, a paint may observe weights mid-update.
     */
    private static final class NetworkFrame {
        final float[][][] network;
        final int iter;

        NetworkFrame(float[][][] network, int iter) {
            this.network = network;
            this.iter = iter;
        }
    }

    /**
     * Panel that draws the SOM grid (top right), each group of cities mapped to the same
     * grid node (bottom), a line from each group to its node, and the iteration number.
     */
    public class DrawingPanel extends JPanel {
        @Override
        public void paintComponent(Graphics g) {
            super.paintComponent(g);

            // Read each shared field once so the whole paint uses one consistent state.
            NetworkFrame frame = current;
            Som classifier = som;
            if (frame == null || classifier == null) {
                return; // nothing has been published yet
            }
            float[][][] network = frame.network;

            // Draw every node of the network grid.
            for (int i = 0; i < network.length; i++) {
                float[][] column = network[i];
                for (int j = 0; j < column.length; j++) {
                    drawNetworkNode(i, j, column[j], g);
                }
            }

            // Group city indices by the grid node closest to them, keyed "x-y".
            Map<String, List<Integer>> group = new TreeMap<>();
            for (int i = 0; i < input.length; i++) {
                int[] pos = classifier.selectClosest(network, input[i]);
                group.computeIfAbsent(pos[0] + "-" + pos[1], k -> new ArrayList<>()).add(i);
            }

            int groupIndex = 0;
            for (Map.Entry<String, List<Integer>> entry : group.entrySet()) {
                String[] spl = entry.getKey().split("-");
                int x = Integer.parseInt(spl[0]);
                int y = Integer.parseInt(spl[1]);
                drawLine(groupIndex, x, y, network, g);
                int m = 0;
                for (Integer ci : entry.getValue()) {
                    drawCity(groupIndex, m, input[ci], cities.get(ci).getCity(), g);
                    m++;
                }
                groupIndex++;
            }

            g.setColor(Color.GRAY);
            g.drawString("第" + frame.iter + "轮", 1120, 40);
        }

        /** Draws one city as a colored square with its name, stacked in its group's column. */
        private void drawCity(int groupIndex, int cityIndex, float[] dot, String city, Graphics g) {
            int lineX = 10 + 51 * groupIndex;
            int lineY = 310 + 21 * cityIndex;
            g.setColor(colorOf(dot));
            g.fillRect(lineX, lineY, 20, 20);
            g.setColor(Color.GRAY);
            g.drawString(city, lineX + 21, lineY + 16);
        }

        /** Draws a group's header bar and a line from it to the centre of grid node (x, y). */
        private void drawLine(int groupIndex, int x, int y, float[][][] network, Graphics g) {
            int lineX = 10 + 51 * groupIndex;
            int lineY = 300;
            g.setColor(colorOf(network[x][y]));
            g.fillRect(lineX, lineY, 48, 5);
            g.setColor(Color.BLACK);
            // Node (x, y) is drawn at (x*41+460, y*41+10) with size 40, so +20 targets its centre.
            g.drawLine(lineX + 24, lineY, x * 41 + 480, y * 41 + 30);
        }

        /** Draws grid node (i, j) as a 40x40 square colored by its first two weights. */
        private void drawNetworkNode(int i, int j, float[] dot, Graphics g) {
            int x = i * 41 + 460;
            int y = j * 41 + 10;
            g.setColor(colorOf(dot));
            g.fillRect(x, y, 40, 40);
        }

        /**
         * Maps the first two components of {@code dot} to red/green with fixed blue 0.5.
         * Components are clamped to [0, 1] because {@code new Color(float, float, float)}
         * throws IllegalArgumentException outside that range.
         */
        private Color colorOf(float[] dot) {
            return new Color(clamp01(dot[0]), clamp01(dot[1]), 0.5F);
        }

        private float clamp01(float v) {
            return Math.max(0F, Math.min(1F, v));
        }
    }

    /** Entry point: builds and shows the demo window on the Swing event dispatch thread. */
    public static void main(String[] args) {
        SwingUtilities.invokeLater(() -> new ClassificationDemo().setVisible(true));
    }

}